{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import gurobipy as gp \n",
    "import numpy as np \n",
    "import networkx as nx\n",
    "from opt_w_predictions import *\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from scipy.sparse import coo_matrix,csr_matrix\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define optimization oracles for shortest path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def shortest_path_oracle(sources, destinations, start_node, end_node, quiet=True): \n",
    "    n_nodes = length(nodes)\n",
    "    n_edges = length(sources)\n",
    "    d_feasibleregion = length(sources)\n",
    "    # Hard code sparse node-edge incidence matrix!\n",
    "    I_vec = np.hstack([sources,destinations])\n",
    "    J_vec = np.hstack([range(n_edges),range(n_edges)])\n",
    "    V_vec = np.hstack([-np.ones(n_edges), np.ones(n_edges)])\n",
    "\n",
    "    A_mat = coo_matrix((V_vec, (I_vec, J_vec)))\n",
    "    A_mat = csr_matrix(A_mat)\n",
    "    bvec = np.zeros(n_nodes)\n",
    "    bvec[start_node] = -1; bvec[end_node] = 1; \n",
    "    m = gp.Model()\n",
    "    if quiet: m.setParam(\"OutputFlag\", 0)\n",
    "    w = m.addMVar(n_edges, lb= 0, ub = 1)\n",
    "    m.addConstrs(A_mat[i,:] @ w == bvec[i] for i in range(n_nodes) )\n",
    "    def local_sp_oracle_jump(c):\n",
    "        m.setObjective( c @ w, gp.GRB.MINIMIZE)\n",
    "        m.optimize()\n",
    "        z_ast = m.objVal\n",
    "        w_ast = np.asarray([w[i].X for i in range(len(c))]).flatten()\n",
    "        return [z_ast, w_ast]\n",
    "    return local_sp_oracle_jump\n",
    "\n",
    "\n",
    "def serializable_shortest_path_oracle(sources, destinations, start_node, end_node, quiet=True): \n",
    "    ''' Precomputing the model doesn't precompute for gurobi \n",
    "    '''\n",
    "    n_nodes = length(nodes)\n",
    "    n_edges = length(sources)\n",
    "    d_feasibleregion = length(sources)\n",
    "    # Hard code sparse node-edge incidence matrix!\n",
    "    I_vec = np.hstack([sources,destinations])\n",
    "    J_vec = np.hstack([range(n_edges),range(n_edges)])\n",
    "    V_vec = np.hstack([-np.ones(n_edges), np.ones(n_edges)])\n",
    "\n",
    "    A_mat = coo_matrix((V_vec, (I_vec, J_vec)))\n",
    "    A_mat = csr_matrix(A_mat)\n",
    "    bvec = np.zeros(n_nodes)\n",
    "    bvec[start_node] = -1; bvec[end_node] = 1; \n",
    "    \n",
    "    def serializable_local_sp_oracle_jump(c):\n",
    "        m = gp.Model()\n",
    "        if quiet: m.setParam(\"OutputFlag\", 0)\n",
    "        w = m.addMVar(n_edges, lb= 0, ub = 1)\n",
    "        m.addConstrs(A_mat[i,:] @ w == bvec[i] for i in range(n_nodes) )\n",
    "        m.setObjective( c @ w, gp.GRB.MINIMIZE)\n",
    "        m.optimize()\n",
    "        z_ast = m.objVal\n",
    "        w_ast = np.asarray([w[i].X for i in range(len(c))]).flatten()\n",
    "        return [z_ast, w_ast]\n",
    "    return serializable_local_sp_oracle_jump"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the SPO paper: \n",
    "In the following set of experiments on the shortest path problem we described, we fix the number of features at p = 5 throughout and, as previously mentioned, use a 5 × 5 grid network, which implies that d = 40. Hence, in total there are pd = 200 parameters to estimate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using license file /Users/connorlawless/gurobi.lic\n",
      "Academic license - for non-commercial use only\n",
      "{(0, 0): 0, (0, 1): 1, (0, 2): 2, (0, 3): 3, (0, 4): 4, (1, 0): 5, (1, 1): 6, (1, 2): 7, (1, 3): 8, (1, 4): 9, (2, 0): 10, (2, 1): 11, (2, 2): 12, (2, 3): 13, (2, 4): 14, (3, 0): 15, (3, 1): 16, (3, 2): 17, (3, 3): 18, (3, 4): 19, (4, 0): 20, (4, 1): 21, (4, 2): 22, (4, 3): 23, (4, 4): 24}\n",
      "Using license file /Users/connorlawless/gurobi.lic\n",
      "Academic license - for non-commercial use only\n"
     ]
    }
   ],
   "source": [
    "gurobiEnvOracle = gp.Env()\n",
    "[sources, destinations, scalar_nodes, tuple_nodes] = convert_grid_to_list(5, 5)\n",
    "grid_dim = 5\n",
    "# start_node = 0; end_node = grid_dim**2-1\n",
    "nodes = np.unique(list(set(sources).union(set(destinations))))\n",
    "d_feasibleregion = length(sources)\n",
    "start_node = scalar_nodes[(0,0)]\n",
    "end_node = scalar_nodes[(4,4)]\n",
    "sp_oracle = shortest_path_oracle(sources, destinations, start_node, end_node)\n",
    "para_sp_oracle = serializable_shortest_path_oracle(sources, destinations, start_node, end_node)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "    # Generate Data\n",
    "p_features=5\n",
    "B_true = rand(1,0.5,size=(d_feasibleregion, p_features))\n",
    "n_train = 500\n",
    "n_test = 10000; n_holdout = 1000 \n",
    "polykernel_degree_vec = [1, 2, 4, 6, 8]\n",
    "polykernel_noise_half_width_vec = [0, 0.5]\n",
    "polykernel_degree = polykernel_degree_vec[-1]\n",
    "polykernel_noise_half_width = polykernel_noise_half_width_vec[1]\n",
    "\n",
    "(X_train, c_train) = generate_poly_kernel_data_simple(B_true, n_train, polykernel_degree, polykernel_noise_half_width)\n",
    "(X_validation, c_validation) = generate_poly_kernel_data_simple(B_true, n_holdout, polykernel_degree, polykernel_noise_half_width)\n",
    "(X_test, c_test) = generate_poly_kernel_data_simple(B_true, n_test, polykernel_degree, polykernel_noise_half_width)\n",
    "\n",
    "# Add intercept in the first row of X\n",
    "# X is p x N \n",
    "X_train = vcat(ones(1,n_train), X_train)\n",
    "X_validation = vcat(ones(1,n_holdout), X_validation)\n",
    "X_test = vcat(ones(1,n_test), X_test)\n",
    "\n",
    "# Get dimensions of input\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feasible_least_squares(regressor, x_star_regr, mu, c_train, X_train, sp_oracle, random_regr=False):\n",
    "    '''\n",
    "    One-step plug in \n",
    "    regressor: type of Scikit-Learn regressor\n",
    "    x_star_regr: x^* - \\hat x^* \n",
    "    c_train: training data, cost vector realizations\n",
    "    X_train: training data \n",
    "    '''\n",
    "    # Mix decision weights with uniform \n",
    "    weights = (mu*x_star_regr + (1-mu))*np.ones(c_train.shape)\n",
    "    weighted_predictors = get_weighted_predictors(regressor, c_train, X_train, weights, random_regr=random_regr)\n",
    "    [reweighted_regrets, reweighted_x_star_regr] = get_regret(weighted_predictors,\n",
    "            X_train, c_train, X_test, c_test,sp_oracle, quiet=False)\n",
    "    return [weighted_predictors, reweighted_regrets, reweighted_x_star_regr]\n",
    "    \n",
    "def run_replication_over_weights(data_params, X_test, c_test, mixture_weights, \n",
    "                                 regressor, sp_oracle, random_regr=False): \n",
    "    '''\n",
    "    Run a replication under fixed data parameters \n",
    "    Assume pre-generated test dataset (to reduce noise in test evaluation)\n",
    "    '''\n",
    "    \n",
    "    [n_train, n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true] = data_params\n",
    "    [X_train, c_train,X_validation, c_validation] = generate_data(n_train, \n",
    "        n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true,gen_test=False)\n",
    "    \n",
    "    # learn initial predictor \n",
    "    predictors = get_weighted_predictors(regressor, c_train, X_train, random_regr=random_regr )\n",
    "    [regrets, x_star_regr] = get_regret(predictors, X_train, c_train, X_test, c_test, sp_oracle)\n",
    "    print(np.mean(np.sqrt(np.square(regrets))))\n",
    "    print(x_star_regr.mean(axis=1))\n",
    "    n_wghts = len(mixture_weights)\n",
    "\n",
    "    original_regrets = np.mean(np.sqrt(np.square(regrets))); \n",
    "    reweighted_mean_regrets = np.zeros(n_wghts)\n",
    "    for k in range(n_wghts): \n",
    "        [weighted_predictors, reweighted_regrets, reweighted_x_star_regr] = feasible_least_squares(\n",
    "            regressor, x_star_regr, mixture_weights[k], c_train, X_train,sp_oracle,random_regr=random_regr)\n",
    "#         print('reweighting ', mixture_weights[k])\n",
    "        reweighted_mean_regrets[k] = np.mean(np.sqrt(np.square(reweighted_regrets)))\n",
    "    return [original_regrets, reweighted_mean_regrets]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "    [X_train, c_train,X_validation, c_validation] = generate_data(n_train, \n",
    "        n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true,gen_test=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(40, 500)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sgd_tr_instance = []\n",
    "\n",
    "for i in range(len(c_train.shape))\n",
    "\n",
    "c_train.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train predictors "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_replication_over_weights(data_params, X_test, c_test, mixture_weights, \n",
    "                                 regressor, sp_oracle, random_regr=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 10000)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(40, 500)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "[X_train, c_train,X_validation, c_validation, X_train, X_test] = generate_data(n_train, \n",
    "            n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true,gen_test=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((40, 10000), (40, 10000))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test.shape, c_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n , 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   34.6s\n",
      "[Parallel(n_jobs=-1)]: Done   2 out of  12 | elapsed:   34.7s remaining:  2.9min\n",
      "[Parallel(n_jobs=-1)]: Done   3 out of  12 | elapsed:   34.7s remaining:  1.7min\n",
      "[Parallel(n_jobs=-1)]: Done   4 out of  12 | elapsed:   34.8s remaining:  1.2min\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of  12 | elapsed:   34.9s remaining:   48.9s\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of  12 | elapsed:   35.0s remaining:   35.0s\n",
      "[Parallel(n_jobs=-1)]: Done   7 out of  12 | elapsed:   35.0s remaining:   25.0s\n",
      "[Parallel(n_jobs=-1)]: Done   8 out of  12 | elapsed:   35.2s remaining:   17.6s\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of  12 | elapsed:   54.0s remaining:   18.0s\n",
      "[Parallel(n_jobs=-1)]: Done  10 out of  12 | elapsed:   54.1s remaining:   10.8s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   54.1s remaining:    0.0s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   54.1s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n , 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   32.4s\n",
      "[Parallel(n_jobs=-1)]: Done   2 out of  12 | elapsed:   32.5s remaining:  2.7min\n",
      "[Parallel(n_jobs=-1)]: Done   3 out of  12 | elapsed:   32.6s remaining:  1.6min\n",
      "[Parallel(n_jobs=-1)]: Done   4 out of  12 | elapsed:   32.6s remaining:  1.1min\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of  12 | elapsed:   32.7s remaining:   45.7s\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of  12 | elapsed:   32.7s remaining:   32.7s\n",
      "[Parallel(n_jobs=-1)]: Done   7 out of  12 | elapsed:   32.8s remaining:   23.4s\n",
      "[Parallel(n_jobs=-1)]: Done   8 out of  12 | elapsed:   32.9s remaining:   16.5s\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of  12 | elapsed:   50.8s remaining:   16.9s\n",
      "[Parallel(n_jobs=-1)]: Done  10 out of  12 | elapsed:   50.8s remaining:   10.2s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   50.9s remaining:    0.0s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   50.9s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n , 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   33.5s\n",
      "[Parallel(n_jobs=-1)]: Done   2 out of  12 | elapsed:   33.7s remaining:  2.8min\n",
      "[Parallel(n_jobs=-1)]: Done   3 out of  12 | elapsed:   33.7s remaining:  1.7min\n",
      "[Parallel(n_jobs=-1)]: Done   4 out of  12 | elapsed:   33.7s remaining:  1.1min\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of  12 | elapsed:   33.7s remaining:   47.2s\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of  12 | elapsed:   33.8s remaining:   33.8s\n",
      "[Parallel(n_jobs=-1)]: Done   7 out of  12 | elapsed:   33.8s remaining:   24.2s\n",
      "[Parallel(n_jobs=-1)]: Done   8 out of  12 | elapsed:   34.0s remaining:   17.0s\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of  12 | elapsed:   51.2s remaining:   17.1s\n",
      "[Parallel(n_jobs=-1)]: Done  10 out of  12 | elapsed:   51.3s remaining:   10.3s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   51.4s remaining:    0.0s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   51.4s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n , 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   32.6s\n",
      "[Parallel(n_jobs=-1)]: Done   2 out of  12 | elapsed:   32.7s remaining:  2.7min\n",
      "[Parallel(n_jobs=-1)]: Done   3 out of  12 | elapsed:   32.7s remaining:  1.6min\n",
      "[Parallel(n_jobs=-1)]: Done   4 out of  12 | elapsed:   32.8s remaining:  1.1min\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of  12 | elapsed:   32.8s remaining:   45.9s\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of  12 | elapsed:   32.9s remaining:   32.9s\n",
      "[Parallel(n_jobs=-1)]: Done   7 out of  12 | elapsed:   33.0s remaining:   23.6s\n",
      "[Parallel(n_jobs=-1)]: Done   8 out of  12 | elapsed:   33.1s remaining:   16.5s\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of  12 | elapsed:   50.5s remaining:   16.8s\n",
      "[Parallel(n_jobs=-1)]: Done  10 out of  12 | elapsed:   50.7s remaining:   10.1s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   50.7s remaining:    0.0s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   50.7s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n , 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   31.3s\n",
      "[Parallel(n_jobs=-1)]: Done   2 out of  12 | elapsed:   31.3s remaining:  2.6min\n",
      "[Parallel(n_jobs=-1)]: Done   3 out of  12 | elapsed:   31.5s remaining:  1.6min\n",
      "[Parallel(n_jobs=-1)]: Done   4 out of  12 | elapsed:   31.5s remaining:  1.0min\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of  12 | elapsed:   31.5s remaining:   44.2s\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of  12 | elapsed:   31.6s remaining:   31.6s\n",
      "[Parallel(n_jobs=-1)]: Done   7 out of  12 | elapsed:   31.6s remaining:   22.6s\n",
      "[Parallel(n_jobs=-1)]: Done   8 out of  12 | elapsed:   31.8s remaining:   15.9s\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of  12 | elapsed:   47.5s remaining:   15.8s\n",
      "[Parallel(n_jobs=-1)]: Done  10 out of  12 | elapsed:   47.7s remaining:    9.5s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   47.8s remaining:    0.0s\n",
      "[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed:   47.8s finished\n"
     ]
    }
   ],
   "source": [
    "import pickle \n",
    "#N_s = [ 100, 250, 500, 1000, 1500, 2000]\n",
    "N_s = [ 100]\n",
    "\n",
    "n_wghts = 5; mixture_weights = np.linspace(1.0/n_wghts,0.99,n_wghts)\n",
    "n_reps = 12\n",
    "\n",
    "# initialize choice of regressor\n",
    "regressor = LinearRegression\n",
    "random_regr=False\n",
    "n_misp=len(polykernel_degree_vec)\n",
    "original_regrets_ = np.zeros([n_misp,len(N_s), n_reps]); reweighted_regrets_ = np.zeros([n_misp,len(N_s), n_reps, n_wghts])\n",
    "\n",
    "\n",
    "for misspec_ind, degree in enumerate(polykernel_degree_vec): \n",
    "    polykernel_degree = degree \n",
    "    for ind_n, en_train in enumerate(N_s): \n",
    "        print('n ,', en_train)\n",
    "        data_params = [en_train, n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true]\n",
    "        [X_train, c_train,X_validation, c_validation, X_train, X_test] = generate_data(n_train, \n",
    "            n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true,gen_test=True)\n",
    "\n",
    "        # Each replication returns [original_regrets, reweighted_mean_regrets (n_wghts) ]\n",
    "        res = Parallel(n_jobs=-1, verbose=20)(delayed(run_replication_over_weights)(\n",
    "            data_params, X_test, c_test, mixture_weights, regressor, para_sp_oracle, random_regr=random_regr) for k in range(n_reps))\n",
    "\n",
    "    #     [original_regrets, reweighted_regrets] = run_replication_over_weights(data_params, mixture_weights, regressor, sp_oracle, random_regr=random_regr)\n",
    "        original_regrets_[misspec_ind,ind_n,:] = np.asarray([res[i][0] for i in range(n_reps)])\n",
    "        reweighted_regrets_[misspec_ind,ind_n,:,:] = np.asarray([res[i][1] for i in range(n_reps)])\n",
    "\n",
    "        exp_name = 'linear_regression_misspec_'+str(misspec_ind)\n",
    "        res_ = dict({'original_regrets_': original_regrets_, 'reweighted_regrets_':reweighted_regrets_, 'data_params':data_params, 'N_s':N_s, 'mixture_weights':mixture_weights})\n",
    "        pickle.dump(res_, open(exp_name+'-res_may18.p', 'wb'))\n",
    "\n",
    "\n",
    "#     print('original regret, ', original_regrets)\n",
    "#     print('reweighted regrets, ', reweighted_regrets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_name = 'linear_regression_'\n",
    "res_ = dict({'original_regrets_': original_regrets_, 'reweighted_regrets_':reweighted_regrets_, 'data_params':data_params, 'N_s':N_s, 'mixture_weights':mixture_weights})\n",
    "pickle.dump(res_, open(exp_name+'-res.p', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "0\n",
      "reweighting  0.2\n",
      "4743.91202177076\n",
      "4854.331112751639\n",
      "0\n",
      "reweighting  0.4\n",
      "4743.91202177076\n",
      "4897.321988720906\n",
      "0\n",
      "reweighting  0.6000000000000001\n",
      "4743.91202177076\n",
      "5085.711101314495\n",
      "0\n",
      "reweighting  0.8\n",
      "4743.91202177076\n",
      "5334.150144065382\n",
      "0\n",
      "reweighting  1.0\n",
      "4743.91202177076\n",
      "6843.877078325886\n"
     ]
    }
   ],
   "source": [
    "n_train = 250 \n",
    "regressor = RandomForestRegressor\n",
    "# [n_train, n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true] = data_params\n",
    "[X_train, c_train,X_validation, c_validation, X_test, c_test] = generate_data(n_train, \n",
    "    n_test, n_holdout, polykernel_degree,polykernel_noise_half_width,B_true,gen_test=True)\n",
    "predictors = get_weighted_predictors(regressor, c_train, X_train)\n",
    "[regrets, x_star_regr] = get_regret(predictors, X_train, c_train, X_test, c_test, sp_oracle)\n",
    "print(np.mean(np.sqrt(np.square(regrets))))\n",
    "n_wghts = len(mixture_weights)\n",
    "\n",
    "for k in range(n_wghts): \n",
    "    [weighted_predictors, reweighted_regrets, reweighted_x_star_regr] = feasible_least_squares(regressor, x_star_regr, mixture_weights[k], c_train, X_train,sp_oracle)\n",
    "    print('reweighting ', mixture_weights[k])\n",
    "    \n",
    "    print(np.mean(np.sqrt(np.square(reweighted_regrets))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "reweighting  0.0\n",
      "3738.462547296788\n",
      "3856.9007494036164\n",
      "0\n",
      "reweighting  0.25\n",
      "3738.462547296788\n",
      "3288.022802404781\n",
      "0\n",
      "reweighting  0.5\n",
      "3738.462547296788\n",
      "3635.9303441225484\n",
      "0\n",
      "reweighting  0.75\n",
      "3738.462547296788\n",
      "3770.76793113731\n",
      "0\n",
      "reweighting  1.0\n",
      "3738.462547296788\n",
      "2358.8529892644383\n"
     ]
    }
   ],
   "source": [
    "\n",
    "    \n",
    "# Feasible predict-then-optimize \n",
    "n_wghts = 5; mixture_weights = np.linspace(0,1,n_wghts)\n",
    "for k in range(n_wghts): \n",
    "    weights = mixture_weights[k]*np.abs(x_star_regr) + (1- mixture_weights[k])*np.ones(c_train.shape)\n",
    "    weighted_predictors = get_weighted_predictors(RandomForestRegressor, c_train, X_train,weights)\n",
    "    [reweighted_regrets, reweighted_x_star_regr] = get_regret(weighted_predictors,X_train, c_train, X_test, c_test,quiet=False)\n",
    "\n",
    "    print('reweighting ', mixture_weights[k])\n",
    "    print(np.mean(np.sqrt(np.square(regrets))))\n",
    "    print(np.mean(np.sqrt(np.square(reweighted_regrets))))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7041773923402886\n",
      "4.70906757162312\n"
     ]
    }
   ],
   "source": [
    "print(np.mean(np.sqrt(np.square(regrets))))\n",
    "print(np.mean(np.sqrt(np.square(reweighted_regrets))))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_star_regr.mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "regrets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Get decision risks on previous dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Run one instance of the shortest path experiment. Uses the reformulation approach.\n",
    "`num_lambda` is the number of lambdas used on the grid for each method with regularization.\n",
    "Returns a replication_results struct.\n",
    "\"\"\"\n",
    "function shortest_path_replication(grid_dim,\n",
    "    n_train, n_holdout, n_test,\n",
    "    p_features, polykernel_degree, polykernel_noise_half_width;\n",
    "    num_lambda = 10, lambda_max = missing, lambda_min_ratio = 0.0001, regularization = :ridge,\n",
    "    , gurobiEnvReform = Gurobi.Env(), different_validation_losses = false,\n",
    "    include_rf = true)\n",
    "\n",
    "    # Get oracle\n",
    "    gurobiEnvOracle = Gurobi.Env()\n",
    "    sources, destinations = convert_grid_to_list(grid_dim, grid_dim)\n",
    "    d_feasibleregion = length(sources)\n",
    "    sp_oracle = sp_flow_jump_setup(sources, destinations, 1, grid_dim^2; gurobiEnv = gurobiEnvOracle)\n",
    "    sp_graph = shortest_path_graph(sources = sources, destinations = destinations,\n",
    "        start_node = 1, end_node = grid_dim^2, acyclic = true)\n",
    "\n",
    "\n",
    "    # Generate Data\n",
    "    B_true = rand(Bernoulli(0.5), d_feasibleregion, p_features)\n",
    "\n",
    "    (X_train, c_train) = generate_poly_kernel_data_simple(B_true, n_train, polykernel_degree, polykernel_noise_half_width)\n",
    "    (X_validation, c_validation) = generate_poly_kernel_data_simple(B_true, n_holdout, polykernel_degree, polykernel_noise_half_width)\n",
    "    (X_test, c_test) = generate_poly_kernel_data_simple(B_true, n_test, polykernel_degree, polykernel_noise_half_width)\n",
    "    \n",
    "    # Add intercept in the first row of X\n",
    "    X_train = vcat(ones(1,n_train), X_train)\n",
    "    X_validation = vcat(ones(1,n_holdout), X_validation)\n",
    "    X_test = vcat(ones(1,n_test), X_test)\n",
    "\n",
    "    # Get Hamming labels\n",
    "    (z_train, w_train) = oracle_dataset(c_train, sp_oracle)\n",
    "    (z_validation, w_validation) = oracle_dataset(c_validation, sp_oracle)\n",
    "    (z_test, w_test) = oracle_dataset(c_test, sp_oracle)\n",
    "\n",
    "    c_ham_train = ones(d_feasibleregion, n_train) - w_train\n",
    "    c_ham_validation = ones(d_feasibleregion, n_holdout) - w_validation\n",
    "    c_ham_test = ones(d_feasibleregion, n_test) - w_test\n",
    "\n",
    "    # Put train + validation together\n",
    "    X_both = hcat(X_train, X_validation)\n",
    "    c_both = hcat(c_train, c_validation)\n",
    "    c_ham_both = hcat(c_ham_train, c_ham_validation)\n",
    "    train_ind = collect(1:n_train)\n",
    "    validation_ind = collect((n_train + 1):(n_train + n_holdout))\n",
    "\n",
    "    # Set validation losses\n",
    "    if different_validation_losses\n",
    "        spo_plus_val_loss = :spo_loss\n",
    "        ls_val_loss = :least_squares_loss\n",
    "        ssvm_val_loss = :hamming_loss\n",
    "        absolute_val_loss = :absolute_loss\n",
    "        huber_val_loss = :huber_loss\n",
    "    else\n",
    "        spo_plus_val_loss = :spo_loss\n",
    "        ls_val_loss = :spo_loss\n",
    "        ssvm_val_loss = :spo_loss\n",
    "        absolute_val_loss = :spo_loss\n",
    "        huber_val_loss = :spo_loss\n",
    "    end\n",
    "\n",
    "    ### Algorithms ###\n",
    "\n",
    "    # SPO+\n",
    "    best_B_SPOplus, best_lambda_SPOplus = validation_set_alg(X_both, c_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :sp_spo_plus_reform, validation_loss = spo_plus_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(num_lambda = num_lambda, lambda_max = lambda_max, regularization = regularization,\n",
    "            gurobiEnv = gurobiEnvReform, lambda_min_ratio = lambda_min_ratio, algorithm_type = :SPO_plus))\n",
    "\n",
    "    # Least squares\n",
    "    best_B_leastSquares, best_lambda_leastSquares = validation_set_alg(X_both, c_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :ls_jump, validation_loss = ls_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(num_lambda = num_lambda, lambda_max = lambda_max, regularization = regularization,\n",
    "            gurobiEnv = gurobiEnvReform, lambda_min_ratio = lambda_min_ratio, algorithm_type = :leastSquares))\n",
    "\n",
    "    # SSVM Hamming\n",
    "    best_B_SSVM, best_lambda_SSVM = validation_set_alg(X_both, c_ham_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :sp_spo_plus_reform, validation_loss = ssvm_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(num_lambda = num_lambda, lambda_max = lambda_max, regularization = regularization,\n",
    "            gurobiEnv = gurobiEnvReform, lambda_min_ratio = lambda_min_ratio, algorithm_type = :SSVM_hamming))\n",
    "\n",
    "    # RF\n",
    "    if include_rf\n",
    "        rf_mods = train_random_forests_po(X_train, c_train;\n",
    "            rf_alg_parms = rf_parms())\n",
    "    end\n",
    "\n",
    "    # Absolute\n",
    "    best_B_Absolute, best_lambda_Absolute = validation_set_alg(X_both, c_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :ls_jump, validation_loss = absolute_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(num_lambda = num_lambda, lambda_max = lambda_max, regularization = regularization,\n",
    "            po_loss_function = :absolute, gurobiEnv = gurobiEnvReform, lambda_min_ratio = lambda_min_ratio, algorithm_type = :Absolute))\n",
    "\n",
    "    # Huber\n",
    "    best_B_fake, best_lambda_fake = validation_set_alg(X_both, c_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :ls_jump, validation_loss = ls_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(lambda_max = 0.0001, regularization = regularization,\n",
    "            num_lambda = 2, gurobiEnv = gurobiEnvReform, algorithm_type = :Huber_LS_fake))\n",
    "\n",
    "    fake_ls_list = abs.(vec(c_train - best_B_fake*X_train))\n",
    "    delta_from_fake_ls = median(fake_ls_list)\n",
    "\n",
    "    best_B_Huber, best_lambda_Huber = validation_set_alg(X_both, c_both, sp_oracle; sp_graph = sp_graph,\n",
    "        train_ind = train_ind, validation_ind = validation_ind,\n",
    "        val_alg_parms = val_parms(algorithm_type = :ls_jump, validation_loss = huber_val_loss),\n",
    "        path_alg_parms = reformulation_path_parms(num_lambda = num_lambda, lambda_max = lambda_max, regularization = regularization, po_loss_function = :huber,\n",
    "            huber_delta = delta_from_fake_ls, lambda_min_ratio = lambda_min_ratio, gurobiEnv = gurobiEnvReform, algorithm_type = :Huber))\n",
    "\n",
    "\n",
    "    # Baseline\n",
    "    c_bar_train = mean(c_train, dims=2)\n",
    "\n",
    "\n",
    "    ### Populate final results ###\n",
    "    final_results = replication_results()\n",
    "\n",
    "    final_results.SPOplus_spoloss_test = spo_loss(best_B_SPOplus, X_test, c_test, sp_oracle)\n",
    "    final_results.LS_spoloss_test = spo_loss(best_B_leastSquares, X_test, c_test, sp_oracle)\n",
    "    final_results.SSVM_spoloss_test = spo_loss(best_B_SSVM, X_test, c_test, sp_oracle)\n",
    "\n",
    "    if include_rf\n",
    "        rf_preds_test = predict_random_forests_po(rf_mods, X_test)\n",
    "        final_results.RF_spoloss_test = spo_loss(Matrix(1.0I, d_feasibleregion, d_feasibleregion), rf_preds_test, c_test, sp_oracle)\n",
    "    else\n",
    "        final_results.RF_spoloss_test = missing\n",
    "    end\n",
    "\n",
    "    final_results.Absolute_spoloss_test = spo_loss(best_B_Absolute, X_test, c_test, sp_oracle)\n",
    "    final_results.Huber_spoloss_test = spo_loss(best_B_Huber, X_test, c_test, sp_oracle)\n",
    "\n",
    "    c_bar_test_preds = zeros(d_feasibleregion, n_test)\n",
    "    for i = 1:n_test\n",
    "        c_bar_test_preds[:, i] = c_bar_train\n",
    "    end\n",
    "    final_results.Baseline_spoloss_test = spo_loss(Matrix(1.0I, d_feasibleregion, d_feasibleregion), c_bar_test_preds, c_test, sp_oracle)\n",
    "\n",
    "\n",
    "    # Now Hamming\n",
    "    final_results.SPOplus_hammingloss_test = spo_loss(best_B_SPOplus, X_test, c_ham_test, sp_oracle)\n",
    "    final_results.LS_hammingloss_test = spo_loss(best_B_leastSquares, X_test, c_ham_test, sp_oracle)\n",
    "    final_results.SSVM_hammingloss_test = spo_loss(best_B_SSVM, X_test, c_ham_test, sp_oracle)\n",
    "\n",
    "    if include_rf\n",
    "        final_results.RF_hammingloss_test = spo_loss(Matrix(1.0I, d_feasibleregion, d_feasibleregion), rf_preds_test, c_ham_test, sp_oracle)\n",
    "    else\n",
    "        final_results.RF_hammingloss_test = missing\n",
    "    end\n",
    "\n",
    "    final_results.Absolute_hammingloss_test = spo_loss(best_B_Absolute, X_test, c_ham_test, sp_oracle)\n",
    "    final_results.Huber_hammingloss_test = spo_loss(best_B_Huber, X_test, c_ham_test, sp_oracle)\n",
    "    final_results.Baseline_hammingloss_test = spo_loss(Matrix(1.0I, d_feasibleregion, d_feasibleregion), c_bar_test_preds, c_ham_test, sp_oracle)\n",
    "\n",
    "    final_results.zstar_avg_test = mean(z_test)\n",
    "\n",
    "    return final_results\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
